#[burn_tensor_testgen::testgen(ad_conv_transpose3d)]
mod tests {
    use super::*;
    use burn_tensor::{Shape, Tolerance, module::conv_transpose3d, ops::ConvTransposeOptions};

    /// Gradient check for a plain 3-D transposed convolution: stride 1, no padding,
    /// no output padding, no dilation, a single group, 2 -> 2 channels, a 3x3x3
    /// kernel, a 4x4x4 input and a batch of 2.
    ///
    /// With the usual transposed-convolution size formula
    /// `(in - 1) * stride - 2 * padding + dilation * (kernel - 1) + padding_out + 1`
    /// each spatial output extent is `3 + 2 + 1 = 6`, i.e. 216 outputs per channel
    /// per sample. Assuming `backward()` on the non-scalar output seeds an all-ones
    /// gradient, the bias gradient is therefore `2 * 216 = 432` per output channel.
    #[test]
    fn test_conv_transpose3d_basic() {
        let test = ConvTranspose3dTestCase {
            batch_size: 2,
            channels: [2, 2],
            kernel_size: [3, 3, 3],
            padding: [0, 0, 0],
            padding_out: [0, 0, 0],
            stride: [1, 1, 1],
            dilation: [1, 1, 1],
            groups: 1,
            size: [4, 4, 4],
        };
        let device = Default::default();
        let grads = Grads {
            // Expected input gradient, shape [2, 2, 4, 4, 4]; the values are constant
            // within each input channel.
            x: TestTensor::from_floats(
                [
                    [
                        [
                            [
                                [13.250001, 13.250001, 13.250001, 13.250001],
                                [13.250001, 13.250001, 13.250001, 13.250001],
                                [13.250001, 13.250001, 13.250001, 13.250001],
                                [13.250001, 13.250001, 13.250001, 13.250001],
                            ],
                            [
                                [13.250001, 13.250001, 13.250001, 13.250001],
                                [13.250001, 13.250001, 13.250001, 13.250001],
                                [13.250001, 13.250001, 13.250001, 13.250001],
                                [13.250001, 13.250001, 13.250001, 13.250001],
                            ],
                            [
                                [13.250001, 13.250001, 13.250001, 13.250001],
                                [13.250001, 13.250001, 13.250001, 13.250001],
                                [13.250001, 13.250001, 13.250001, 13.250001],
                                [13.250001, 13.250001, 13.250001, 13.250001],
                            ],
                            [
                                [13.250001, 13.250001, 13.250001, 13.250001],
                                [13.250001, 13.250001, 13.250001, 13.250001],
                                [13.250001, 13.250001, 13.250001, 13.250001],
                                [13.250001, 13.250001, 13.250001, 13.250001],
                            ],
                        ],
                        [
                            [
                                [40.249992, 40.249992, 40.249992, 40.249992],
                                [40.249992, 40.249992, 40.249992, 40.249992],
                                [40.249992, 40.249992, 40.249992, 40.249992],
                                [40.249992, 40.249992, 40.249992, 40.249992],
                            ],
                            [
                                [40.249992, 40.249992, 40.249992, 40.249992],
                                [40.249992, 40.249992, 40.249992, 40.249992],
                                [40.249992, 40.249992, 40.249992, 40.249992],
                                [40.249992, 40.249992, 40.249992, 40.249992],
                            ],
                            [
                                [40.249992, 40.249992, 40.249992, 40.249992],
                                [40.249992, 40.249992, 40.249992, 40.249992],
                                [40.249992, 40.249992, 40.249992, 40.249992],
                                [40.249992, 40.249992, 40.249992, 40.249992],
                            ],
                            [
                                [40.249992, 40.249992, 40.249992, 40.249992],
                                [40.249992, 40.249992, 40.249992, 40.249992],
                                [40.249992, 40.249992, 40.249992, 40.249992],
                                [40.249992, 40.249992, 40.249992, 40.249992],
                            ],
                        ],
                    ],
                    [
                        [
                            [
                                [13.250001, 13.250001, 13.250001, 13.250001],
                                [13.250001, 13.250001, 13.250001, 13.250001],
                                [13.250001, 13.250001, 13.250001, 13.250001],
                                [13.250001, 13.250001, 13.250001, 13.250001],
                            ],
                            [
                                [13.250001, 13.250001, 13.250001, 13.250001],
                                [13.250001, 13.250001, 13.250001, 13.250001],
                                [13.250001, 13.250001, 13.250001, 13.250001],
                                [13.250001, 13.250001, 13.250001, 13.250001],
                            ],
                            [
                                [13.250001, 13.250001, 13.250001, 13.250001],
                                [13.250001, 13.250001, 13.250001, 13.250001],
                                [13.250001, 13.250001, 13.250001, 13.250001],
                                [13.250001, 13.250001, 13.250001, 13.250001],
                            ],
                            [
                                [13.250001, 13.250001, 13.250001, 13.250001],
                                [13.250001, 13.250001, 13.250001, 13.250001],
                                [13.250001, 13.250001, 13.250001, 13.250001],
                                [13.250001, 13.250001, 13.250001, 13.250001],
                            ],
                        ],
                        [
                            [
                                [40.249992, 40.249992, 40.249992, 40.249992],
                                [40.249992, 40.249992, 40.249992, 40.249992],
                                [40.249992, 40.249992, 40.249992, 40.249992],
                                [40.249992, 40.249992, 40.249992, 40.249992],
                            ],
                            [
                                [40.249992, 40.249992, 40.249992, 40.249992],
                                [40.249992, 40.249992, 40.249992, 40.249992],
                                [40.249992, 40.249992, 40.249992, 40.249992],
                                [40.249992, 40.249992, 40.249992, 40.249992],
                            ],
                            [
                                [40.249992, 40.249992, 40.249992, 40.249992],
                                [40.249992, 40.249992, 40.249992, 40.249992],
                                [40.249992, 40.249992, 40.249992, 40.249992],
                                [40.249992, 40.249992, 40.249992, 40.249992],
                            ],
                            [
                                [40.249992, 40.249992, 40.249992, 40.249992],
                                [40.249992, 40.249992, 40.249992, 40.249992],
                                [40.249992, 40.249992, 40.249992, 40.249992],
                                [40.249992, 40.249992, 40.249992, 40.249992],
                            ],
                        ],
                    ],
                ],
                &device,
            ),
            // Expected weight gradient, shape [2, 2, 3, 3, 3]
            // ([in_channels, out_channels / groups, k, k, k]); the values are constant
            // within each input channel.
            weight: TestTensor::from_floats(
                [
                    [
                        [
                            [
                                [47.750000, 47.750000, 47.750000],
                                [47.750000, 47.750000, 47.750000],
                                [47.750000, 47.750000, 47.750000],
                            ],
                            [
                                [47.750000, 47.750000, 47.750000],
                                [47.750000, 47.750000, 47.750000],
                                [47.750000, 47.750000, 47.750000],
                            ],
                            [
                                [47.750000, 47.750000, 47.750000],
                                [47.750000, 47.750000, 47.750000],
                                [47.750000, 47.750000, 47.750000],
                            ],
                        ],
                        [
                            [
                                [47.750000, 47.750000, 47.750000],
                                [47.750000, 47.750000, 47.750000],
                                [47.750000, 47.750000, 47.750000],
                            ],
                            [
                                [47.750000, 47.750000, 47.750000],
                                [47.750000, 47.750000, 47.750000],
                                [47.750000, 47.750000, 47.750000],
                            ],
                            [
                                [47.750000, 47.750000, 47.750000],
                                [47.750000, 47.750000, 47.750000],
                                [47.750000, 47.750000, 47.750000],
                            ],
                        ],
                    ],
                    [
                        [
                            [
                                [79.750000, 79.750000, 79.750000],
                                [79.750000, 79.750000, 79.750000],
                                [79.750000, 79.750000, 79.750000],
                            ],
                            [
                                [79.750000, 79.750000, 79.750000],
                                [79.750000, 79.750000, 79.750000],
                                [79.750000, 79.750000, 79.750000],
                            ],
                            [
                                [79.750000, 79.750000, 79.750000],
                                [79.750000, 79.750000, 79.750000],
                                [79.750000, 79.750000, 79.750000],
                            ],
                        ],
                        [
                            [
                                [79.750000, 79.750000, 79.750000],
                                [79.750000, 79.750000, 79.750000],
                                [79.750000, 79.750000, 79.750000],
                            ],
                            [
                                [79.750000, 79.750000, 79.750000],
                                [79.750000, 79.750000, 79.750000],
                                [79.750000, 79.750000, 79.750000],
                            ],
                            [
                                [79.750000, 79.750000, 79.750000],
                                [79.750000, 79.750000, 79.750000],
                                [79.750000, 79.750000, 79.750000],
                            ],
                        ],
                    ],
                ],
                &device,
            ),
            // One entry per output channel: batch (2) * output elements per channel (6^3).
            bias: TestTensor::from_floats([432., 432.], &device),
        };
        test.assert_grads(grads);
    }

    /// Gradient check that exercises every option at once, with distinct per-axis
    /// values: 4 -> 2 channels split into 2 groups, a non-cubic 2x3x4 kernel, and
    /// per-axis stride, padding, output padding and dilation on a 6x6x6 input with a
    /// batch of 1.
    ///
    /// Using `(in - 1) * stride - 2 * padding + dilation * (kernel - 1) + padding_out + 1`,
    /// the output extents are 11, 18 and 27, giving `11 * 18 * 27 = 5346` outputs per
    /// channel. Assuming `backward()` seeds an all-ones gradient, 5346 is also the
    /// expected bias gradient per output channel.
    #[test]
    fn test_conv_transpose3d_complex_groups() {
        let test = ConvTranspose3dTestCase {
            batch_size: 1,
            channels: [4, 2],
            kernel_size: [2, 3, 4],
            padding: [1, 2, 3],
            padding_out: [1, 2, 3],
            stride: [2, 3, 4],
            dilation: [1, 2, 3],
            groups: 2,
            size: [6, 6, 6],
        };
        let device = Default::default();
        let grads = Grads {
            // Expected input gradient, shape [1, 4, 6, 6, 6].
            x: TestTensor::from_floats(
                [[
                    [
                        [
                            [1.250000, 1.625000, 1.625000, 1.625000, 1.625000, 1.625000],
                            [1.687500, 2.187500, 2.187500, 2.187500, 2.187500, 2.187500],
                            [1.687500, 2.187500, 2.187500, 2.187500, 2.187500, 2.187500],
                            [1.687500, 2.187500, 2.187500, 2.187500, 2.187500, 2.187500],
                            [1.687500, 2.187500, 2.187500, 2.187500, 2.187500, 2.187500],
                            [1.687500, 2.187500, 2.187500, 2.187500, 2.187500, 2.187500],
                        ],
                        [
                            [1.750000, 2.250000, 2.250000, 2.250000, 2.250000, 2.250000],
                            [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
                            [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
                            [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
                            [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
                            [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
                        ],
                        [
                            [1.750000, 2.250000, 2.250000, 2.250000, 2.250000, 2.250000],
                            [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
                            [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
                            [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
                            [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
                            [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
                        ],
                        [
                            [1.750000, 2.250000, 2.250000, 2.250000, 2.250000, 2.250000],
                            [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
                            [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
                            [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
                            [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
                            [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
                        ],
                        [
                            [1.750000, 2.250000, 2.250000, 2.250000, 2.250000, 2.250000],
                            [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
                            [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
                            [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
                            [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
                            [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
                        ],
                        [
                            [1.750000, 2.250000, 2.250000, 2.250000, 2.250000, 2.250000],
                            [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
                            [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
                            [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
                            [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
                            [2.250000, 2.875000, 2.875000, 2.875000, 2.875000, 2.875000],
                        ],
                    ],
                    [
                        [
                            [2.750000, 3.625000, 3.625000, 3.625000, 3.625000, 3.625000],
                            [3.937500, 5.187500, 5.187500, 5.187500, 5.187500, 5.187500],
                            [3.937500, 5.187500, 5.187500, 5.187500, 5.187500, 5.187500],
                            [3.937500, 5.187500, 5.187500, 5.187500, 5.187500, 5.187500],
                            [3.937500, 5.187500, 5.187500, 5.187500, 5.187500, 5.187500],
                            [3.937500, 5.187500, 5.187500, 5.187500, 5.187500, 5.187500],
                        ],
                        [
                            [4.750000, 6.250000, 6.250000, 6.250000, 6.250000, 6.250000],
                            [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
                            [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
                            [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
                            [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
                            [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
                        ],
                        [
                            [4.750000, 6.250000, 6.250000, 6.250000, 6.250000, 6.250000],
                            [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
                            [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
                            [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
                            [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
                            [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
                        ],
                        [
                            [4.750000, 6.250000, 6.250000, 6.250000, 6.250000, 6.250000],
                            [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
                            [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
                            [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
                            [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
                            [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
                        ],
                        [
                            [4.750000, 6.250000, 6.250000, 6.250000, 6.250000, 6.250000],
                            [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
                            [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
                            [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
                            [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
                            [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
                        ],
                        [
                            [4.750000, 6.250000, 6.250000, 6.250000, 6.250000, 6.250000],
                            [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
                            [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
                            [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
                            [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
                            [6.750000, 8.875000, 8.875000, 8.875000, 8.875000, 8.875000],
                        ],
                    ],
                    [
                        [
                            [4.250000, 5.625000, 5.625000, 5.625000, 5.625000, 5.625000],
                            [6.187500, 8.187500, 8.187500, 8.187500, 8.187500, 8.187500],
                            [6.187500, 8.187500, 8.187500, 8.187500, 8.187500, 8.187500],
                            [6.187500, 8.187500, 8.187500, 8.187500, 8.187500, 8.187500],
                            [6.187500, 8.187500, 8.187500, 8.187500, 8.187500, 8.187500],
                            [6.187500, 8.187500, 8.187500, 8.187500, 8.187500, 8.187500],
                        ],
                        [
                            [
                                7.750000, 10.250000, 10.250000, 10.250000, 10.250000, 10.250000,
                            ],
                            [
                                11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
                            ],
                            [
                                11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
                            ],
                            [
                                11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
                            ],
                            [
                                11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
                            ],
                            [
                                11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
                            ],
                        ],
                        [
                            [
                                7.750000, 10.250000, 10.250000, 10.250000, 10.250000, 10.250000,
                            ],
                            [
                                11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
                            ],
                            [
                                11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
                            ],
                            [
                                11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
                            ],
                            [
                                11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
                            ],
                            [
                                11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
                            ],
                        ],
                        [
                            [
                                7.750000, 10.250000, 10.250000, 10.250000, 10.250000, 10.250000,
                            ],
                            [
                                11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
                            ],
                            [
                                11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
                            ],
                            [
                                11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
                            ],
                            [
                                11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
                            ],
                            [
                                11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
                            ],
                        ],
                        [
                            [
                                7.750000, 10.250000, 10.250000, 10.250000, 10.250000, 10.250000,
                            ],
                            [
                                11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
                            ],
                            [
                                11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
                            ],
                            [
                                11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
                            ],
                            [
                                11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
                            ],
                            [
                                11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
                            ],
                        ],
                        [
                            [
                                7.750000, 10.250000, 10.250000, 10.250000, 10.250000, 10.250000,
                            ],
                            [
                                11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
                            ],
                            [
                                11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
                            ],
                            [
                                11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
                            ],
                            [
                                11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
                            ],
                            [
                                11.250000, 14.875000, 14.875000, 14.875000, 14.875000, 14.875000,
                            ],
                        ],
                    ],
                    [
                        [
                            [5.750000, 7.625000, 7.625000, 7.625000, 7.625000, 7.625000],
                            [
                                8.437500, 11.187500, 11.187500, 11.187500, 11.187500, 11.187500,
                            ],
                            [
                                8.437500, 11.187500, 11.187500, 11.187500, 11.187500, 11.187500,
                            ],
                            [
                                8.437500, 11.187500, 11.187500, 11.187500, 11.187500, 11.187500,
                            ],
                            [
                                8.437500, 11.187500, 11.187500, 11.187500, 11.187500, 11.187500,
                            ],
                            [
                                8.437500, 11.187500, 11.187500, 11.187500, 11.187500, 11.187500,
                            ],
                        ],
                        [
                            [
                                10.750000, 14.250000, 14.250000, 14.250000, 14.250000, 14.250000,
                            ],
                            [
                                15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
                            ],
                            [
                                15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
                            ],
                            [
                                15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
                            ],
                            [
                                15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
                            ],
                            [
                                15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
                            ],
                        ],
                        [
                            [
                                10.750000, 14.250000, 14.250000, 14.250000, 14.250000, 14.250000,
                            ],
                            [
                                15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
                            ],
                            [
                                15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
                            ],
                            [
                                15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
                            ],
                            [
                                15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
                            ],
                            [
                                15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
                            ],
                        ],
                        [
                            [
                                10.750000, 14.250000, 14.250000, 14.250000, 14.250000, 14.250000,
                            ],
                            [
                                15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
                            ],
                            [
                                15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
                            ],
                            [
                                15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
                            ],
                            [
                                15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
                            ],
                            [
                                15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
                            ],
                        ],
                        [
                            [
                                10.750000, 14.250000, 14.250000, 14.250000, 14.250000, 14.250000,
                            ],
                            [
                                15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
                            ],
                            [
                                15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
                            ],
                            [
                                15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
                            ],
                            [
                                15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
                            ],
                            [
                                15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
                            ],
                        ],
                        [
                            [
                                10.750000, 14.250000, 14.250000, 14.250000, 14.250000, 14.250000,
                            ],
                            [
                                15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
                            ],
                            [
                                15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
                            ],
                            [
                                15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
                            ],
                            [
                                15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
                            ],
                            [
                                15.750000, 20.875000, 20.875000, 20.875000, 20.875000, 20.875000,
                            ],
                        ],
                    ],
                ]],
                &device,
            ),
            // Expected weight gradient, shape [4, 1, 2, 3, 4]
            // ([in_channels, out_channels / groups, kD, kH, kW]).
            weight: TestTensor::from_floats(
                [
                    [[
                        [
                            [18.663193, 22.309027, 22.309027, 22.309027],
                            [21.875000, 26.145834, 26.145834, 26.145834],
                            [21.875000, 26.145834, 26.145834, 26.145834],
                        ],
                        [
                            [19.270832, 23.020834, 23.020834, 23.020834],
                            [22.500000, 26.875002, 26.875002, 26.875002],
                            [22.500000, 26.875002, 26.875002, 26.875002],
                        ],
                    ]],
                    [[
                        [
                            [49.913193, 59.809029, 59.809029, 59.809029],
                            [59.375000, 71.145836, 71.145836, 71.145836],
                            [59.375000, 71.145836, 71.145836, 71.145836],
                        ],
                        [
                            [56.770836, 68.020836, 68.020836, 68.020836],
                            [67.500000, 80.875000, 80.875000, 80.875000],
                            [67.500000, 80.875000, 80.875000, 80.875000],
                        ],
                    ]],
                    [[
                        [
                            [81.163193, 97.309029, 97.309029, 97.309029],
                            [96.875000, 116.145828, 116.145828, 116.145828],
                            [96.875000, 116.145828, 116.145828, 116.145828],
                        ],
                        [
                            [94.270828, 113.020828, 113.020828, 113.020828],
                            [112.500000, 134.875000, 134.875000, 134.875000],
                            [112.500000, 134.875000, 134.875000, 134.875000],
                        ],
                    ]],
                    [[
                        [
                            [112.413200, 134.809021, 134.809021, 134.809021],
                            [134.375000, 161.145828, 161.145828, 161.145828],
                            [134.375000, 161.145828, 161.145828, 161.145828],
                        ],
                        [
                            [131.770844, 158.020828, 158.020828, 158.020828],
                            [157.500000, 188.875000, 188.875000, 188.875000],
                            [157.500000, 188.875000, 188.875000, 188.875000],
                        ],
                    ]],
                ],
                &device,
            ),
            // One entry per output channel: batch (1) * 11 * 18 * 27 output elements.
            bias: TestTensor::from_floats([5346., 5346.], &device),
        };
        test.assert_grads(grads);
    }

    /// Configuration of one `conv_transpose3d` autodiff gradient check.
    ///
    /// Every `[usize; 3]` field lists the three spatial axes in the same order as the
    /// last three dimensions of the 5-D input built by `assert_grads`
    /// (presumably depth, height, width).
    struct ConvTranspose3dTestCase {
        // --- input geometry ---
        /// Number of samples in the input batch (dimension 0 of the input).
        batch_size: usize,
        /// `[input_channels, output_channels]`.
        channels: [usize; 2],
        /// Spatial extent of the input.
        size: [usize; 3],

        // --- kernel and convolution options ---
        /// Spatial extent of the kernel.
        kernel_size: [usize; 3],
        /// Stride along each spatial axis.
        stride: [usize; 3],
        /// Padding along each spatial axis.
        padding: [usize; 3],
        /// Output padding along each spatial axis.
        padding_out: [usize; 3],
        /// Dilation along each spatial axis.
        dilation: [usize; 3],
        /// Number of channel groups; `channels[1]` is divided by this to size the weight.
        groups: usize,
    }

    /// Reference gradients that `ConvTranspose3dTestCase::assert_grads` compares the
    /// autodiff results against.
    struct Grads {
        /// Expected gradient with respect to the 5-D convolution input.
        x: TestTensor<5>,
        /// Expected gradient with respect to the 5-D transposed-convolution weight.
        weight: TestTensor<5>,
        /// Expected gradient with respect to the bias, one value per output channel.
        bias: TestTensor<1>,
    }

    impl ConvTranspose3dTestCase {
        fn assert_grads(self, expected_grads: Grads) {
            let shape_x = Shape::new([
                self.batch_size,
                self.channels[0],
                self.size[0],
                self.size[1],
                self.size[2],
            ]);
            let shape_weight = Shape::new([
                self.channels[0],
                self.channels[1] / self.groups,
                self.kernel_size[0],
                self.kernel_size[1],
                self.kernel_size[2],
            ]);
            let device = Default::default();
            let weight = TestAutodiffTensor::from_data(
                TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
                    .reshape::<5, _>(shape_weight.clone())
                    .into_data(),
                &device,
            )
            .div_scalar(shape_weight.num_elements() as f32)
            .require_grad();
            let bias = TestAutodiffTensor::from_data(
                TestTensorInt::arange(0..self.channels[1] as i64, &device).into_data(),
                &device,
            )
            .require_grad();
            let x = TestAutodiffTensor::from_data(
                TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
                    .reshape::<5, _>(shape_x.clone())
                    .into_data(),
                &device,
            )
            .div_scalar(shape_x.num_elements() as f32)
            .require_grad();
            let output = conv_transpose3d(
                x.clone(),
                weight.clone(),
                Some(bias.clone()),
                ConvTransposeOptions::new(
                    self.stride,
                    self.padding,
                    self.padding_out,
                    self.dilation,
                    self.groups,
                ),
            );
            let grads = output.backward();

            // Assert
            let x_grad_actual = x.grad(&grads).unwrap();
            let weight_grad_actual = weight.grad(&grads).unwrap();
            let bias_grad_actual = bias.grad(&grads).unwrap();

            let tolerance = Tolerance::permissive();
            expected_grads
                .bias
                .to_data()
                .assert_approx_eq::<FloatType>(&bias_grad_actual.to_data(), tolerance);
            expected_grads
                .x
                .to_data()
                .assert_approx_eq::<FloatType>(&x_grad_actual.to_data(), tolerance);
            expected_grads
                .weight
                .to_data()
                .assert_approx_eq::<FloatType>(&weight_grad_actual.to_data(), tolerance);
        }
    }
}
